{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "efdbb171-ebf1-4465-871a-6add225b8313",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8ae86b02c05d4d27a353dcb88a484249",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'那太好了！请分享一下你为什么开心？'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "device = \"cuda\" # the device to load the model onto\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"/home/lyz/hf-models/Qwen/Qwen1.5-4B-Chat/\",\n",
    "    torch_dtype=\"auto\",\n",
    "    device_map=\"auto\"\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"/home/lyz/hf-models/Qwen/Qwen1.5-4B-Chat/\")\n",
    "\n",
    "prompt = \"我今天很开心\"\n",
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "    {\"role\": \"user\", \"content\": prompt}\n",
    "]\n",
    "text = tokenizer.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True\n",
    ")\n",
    "model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n",
    "\n",
    "generated_ids = model.generate(\n",
    "    model_inputs.input_ids,\n",
    "    max_new_tokens=512\n",
    ")\n",
    "generated_ids = [\n",
    "    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
    "]\n",
    "\n",
    "response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f248c8e5-67a5-4b13-81f9-41e8521ad5b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取数据集，这里是直接联网读取，也可以通过下载文件，再读取\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n",
    "train_data = pd.read_csv(data_dir + 'intent-classify/train.csv', sep='\\t', header=None)\n",
    "test_data = pd.read_csv(data_dir + 'intent-classify/test.csv', sep='\\t', header=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a6d1823c-e857-48d4-b82c-44ea52206640",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>还有双鸭山到淮阴的汽车票吗13号的</td>\n",
       "      <td>Travel-Query</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>从这里怎么回家</td>\n",
       "      <td>Travel-Query</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>随便播放一首专辑阁楼里的佛里的歌</td>\n",
       "      <td>Music-Play</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>给看一下墓王之王嘛</td>\n",
       "      <td>FilmTele-Play</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>我想看挑战两把s686打突变团竞的游戏视频</td>\n",
       "      <td>Video-Play</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12095</th>\n",
       "      <td>一千六百五十三加三千一百六十五点六五等于几</td>\n",
       "      <td>Calendar-Query</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12096</th>\n",
       "      <td>稍小点客厅空调风速</td>\n",
       "      <td>HomeAppliance-Control</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12097</th>\n",
       "      <td>黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席</td>\n",
       "      <td>Radio-Listen</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12098</th>\n",
       "      <td>百事盖世群星星光演唱会有谁</td>\n",
       "      <td>Video-Play</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12099</th>\n",
       "      <td>下周一视频会议的闹钟帮我开开</td>\n",
       "      <td>Alarm-Update</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>12100 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                 0                      1\n",
       "0                还有双鸭山到淮阴的汽车票吗13号的           Travel-Query\n",
       "1                          从这里怎么回家           Travel-Query\n",
       "2                 随便播放一首专辑阁楼里的佛里的歌             Music-Play\n",
       "3                        给看一下墓王之王嘛          FilmTele-Play\n",
       "4            我想看挑战两把s686打突变团竞的游戏视频             Video-Play\n",
       "...                            ...                    ...\n",
       "12095        一千六百五十三加三千一百六十五点六五等于几         Calendar-Query\n",
       "12096                    稍小点客厅空调风速  HomeAppliance-Control\n",
       "12097  黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席           Radio-Listen\n",
       "12098                百事盖世群星星光演唱会有谁             Video-Play\n",
       "12099               下周一视频会议的闹钟帮我开开           Alarm-Update\n",
       "\n",
       "[12100 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "733edea2-647a-48e2-95e9-8e79d88d96fa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Travel-Query Music-Play FilmTele-Play Video-Play Radio-Listen HomeAppliance-Control Weather-Query Alarm-Update Calendar-Query TVProgram-Play Audio-Play Other'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "' '.join(train_data[1].unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0c90b916-a8b5-4768-a636-0f1155e72837",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 359 ms, sys: 0 ns, total: 359 ms\n",
      "Wall time: 365 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Music-Play'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "prompt = '''识别句子的意图：播放周杰伦的歌曲\n",
    "待选类别：Travel-Query Music-Play FilmTele-Play Video-Play Radio-Listen HomeAppliance-Control Weather-Query Alarm-Update Calendar-Query TVProgram-Play Audio-Play Other\n",
    "只需要输出类别（从待选类别中选一个），不要其他输出。\n",
    "'''\n",
    "messages = [\n",
    "    {\"role\": \"user\", \"content\": prompt}\n",
    "]\n",
    "text = tokenizer.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True\n",
    ")\n",
    "model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n",
    "\n",
    "generated_ids = model.generate(\n",
    "    model_inputs.input_ids,\n",
    "    max_new_tokens=512,\n",
    "    \n",
    ")\n",
    "generated_ids = [\n",
    "    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
    "]\n",
    "\n",
    "response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6dfe163f-55dd-4cf2-9afb-3c7563d03f79",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "90984788-5411-4cb1-9369-4e7af42b7f65",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_36021/2977000300.py:2: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  for train_text in tqdm_notebook(test_data[0].values):\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3fd75ff95caa4c378a14b06ef2cf0a1e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pred_label = []\n",
    "for train_text in tqdm_notebook(test_data[0].values):\n",
    "    prompt = f'''识别句子的意图：{train_text}\n",
    "    待选类别：Travel-Query Music-Play FilmTele-Play Video-Play Radio-Listen HomeAppliance-Control Weather-Query Alarm-Update Calendar-Query TVProgram-Play Audio-Play Other\n",
    "    只需要输出类别（从待选类别中选一个），不要其他输出。\n",
    "    '''\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": \"你是一个擅长意图分类的自然语言专家。\"},\n",
    "        {\"role\": \"user\", \"content\": prompt}\n",
    "    ]\n",
    "    text = tokenizer.apply_chat_template(\n",
    "        messages,\n",
    "        tokenize=False,\n",
    "        add_generation_prompt=True\n",
    "    )\n",
    "    model_inputs = tokenizer([text], return_tensors=\"pt\").to(device)\n",
    "    \n",
    "    generated_ids = model.generate(\n",
    "        model_inputs.input_ids,\n",
    "        max_new_tokens=512,\n",
    "        \n",
    "    )\n",
    "    generated_ids = [\n",
    "        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
    "    ]\n",
    "    \n",
    "    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "    pred_label += [response]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0d2e7cec-7ad6-4f86-9b05-054afff1a6e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "pd.DataFrame({\n",
    "    'ID': range(1, len(pred_label) + 1),\n",
    "    'Target': pred_label,\n",
    "}).to_csv('nlp_submit.csv', index=None)\n",
    "\n",
    "# 提交一下吧~"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e14d3e0a-6837-48ad-b971-61b72f9ea397",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['TVProgram-Play',\n",
       " 'HomeAppliance-Control',\n",
       " 'Video-Play',\n",
       " 'Travel-Query',\n",
       " 'HomeAppliance-Control',\n",
       " 'TVProgram-Play',\n",
       " 'FilmTele-Play',\n",
       " 'Music-Play',\n",
       " 'Weather-Query',\n",
       " 'Film Tele-Play',\n",
       " 'Travel-Query',\n",
       " 'Music-Play',\n",
       " 'Travel-Query',\n",
       " 'TVProgram',\n",
       " 'Music-Play',\n",
       " 'Alarm-Update',\n",
       " 'Music-Play',\n",
       " 'Weather-Query',\n",
       " 'Weather-Query',\n",
       " 'FilmTele-Play']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_label[:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d864732d-79bf-4a38-a0ec-c41ca2a4bcf4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
